15322
8665
Quelle est la manière la plus simple de transformer un tenseur de forme (batch_size, height, width) rempli de n valeurs en tenseur de forme (batch_size, n, height, width)?
J'ai créé la solution ci-dessous, mais il semble qu'il existe un moyen plus simple et plus rapide de le faire
def batch_tensor_to_onehot (tnsr, classes):
tnsr = tnsr.unsqueeze (1)
res = []
pour cls in range (classes):
res.append ((tnsr == cls) .long ())
retour torch.cat (res, dim = 1) 
Vous pouvez utiliser torch.nn.functional.one_hot.
Pour votre cas:
a = torch.nn.functional.one_hot (tnsr, num_classes = classes)
out = a.permute (0, 3, 1, 2)
|
Vous pouvez également utiliser Tensor.scatter_ qui évite .permute mais est sans doute plus difficile à comprendre que la méthode simple proposée par @Alpha.
def batch_tensor_to_onehot (tnsr, classes):
result = torch.zeros (tnsr.shape [0], classes, * tnsr.shape [1:], dtype = torch.long, device = tnsr.device)
result.scatter_ (1, tnsr.unsqueeze (1), 1)
résultat de retour
Résultats d'analyse comparative
J'étais curieux et j'ai décidé de comparer les trois approches. J'ai trouvé qu'il ne semble pas y avoir de différence relative significative entre les méthodes proposées en ce qui concerne la taille, la largeur ou la hauteur du lot. Principalement le nombre de classes était le facteur distinctif. Bien sûr, comme pour tout autre repère, le kilométrage peut varier.
Les repères ont été collectés en utilisant des indices aléatoires et en utilisant la taille du lot, la hauteur, la largeur = 100. Chaque expérience a été répétée 20 fois avec la moyenne rapportée. L'expérience num_classes = 100 est exécutée une fois avant le profilage pour le préchauffage.
Les résultats du processeur montrent que la méthode originale était probablement la meilleure pour num_classes inférieure à environ 30, tandis que pour le GPU, l'approche scatter_ semble être la plus rapide.
Tests effectués sur Ubuntu 18.04, NVIDIA 2060 Super, i7-9700K
Le code utilisé pour l'analyse comparative est fourni ci-dessous:
importation de la torche
depuis tqdm import tqdm
temps d'importation
importer matplotlib.pyplot comme plt
def batch_tensor_to_onehot_slavka (tnsr, classes):
tnsr = tnsr.unsqueeze (1)
res = []
pour cls in range (classes):
res.append ((tnsr == cls) .long ())
retour torche.cat (res, dim = 1)
def batch_tensor_to_onehot_alpha (tnsr, classes):
result = torch.nn.functional.one_hot (tnsr, num_classes = classes)
return result.permute (0, 3, 1, 2)
def batch_tensor_to_onehot_jodag (tnsr, classes):
result = torch.zeros (tnsr.shape [0], classes, * tnsr.shape [1:], dtype = torch.long, device = tnsr.device)
result.scatter_ (1, tnsr.unsqueeze (1), 1)
résultat de retour
def main ():
num_classes = [2, 10, 25, 50, 100]
hauteur = 100
largeur = 100
bs = [100] x 20
pour d dans ['cpu', 'cuda']:
times_slavka = []
times_alpha = []
times_jodag = []
échauffement = Vrai
pour c dans tqdm ([num_classes [-1]] + num_classes, ncols = 0):
tslavka = 0
talpha = 0
tjodag = 0
pour b en bs:
tnsr = torch.randint (c, (b, hauteur, largeur)). à (appareil = d)
t0 = temps.heure ()
y = batch_tensor_to_onehot_slavka (tnsr, c)
torch.cuda.synchronize ()
tslavka + = temps.heure () - t0
sinon échauffement:
times_slavka.append (tslavka / len (bs))
pour b en bs:
tnsr = torch.randint (c, (b, hauteur, largeur)). à (appareil = d)
t0 = temps.heure ()
y = batch_tensor_to_onehot_alpha (tnsr, c)
torch.cuda.synchronize ()
talpha + = temps.heure () - t0
sinon échauffement:
times_alpha.append (talpha / len (bs))
pour b en bs:
tnsr = torch.randint (c, (b, hauteur, largeur)). à (appareil = d)
t0 = heure.heure ()
y = batch_tensor_to_onehot_jodag (tnsr, c)
torch.cuda.synchronize ()
tjodag + = time.time () - t0
sinon échauffement:
times_jodag.append (tjodag / len (bs))
échauffement = Faux
fig = plt.figure ()
ax = fig.subplots ()
ax.plot (num_classes, times_slavka, label = 'Slavka-cat')
ax.plot (num_classes, times_alpha, label = 'Alpha-one_hot')
ax.plot (num_classes, times_jodag, label = 'jodag-scatter_')
ax.set_xlabel ('num_classes')
ax.set_ylabel ('heure (s)')
ax.set_title (f '{d} benchmark')
ax.legend ()
plt.savefig (f '{d} .png')
plt.show ()
si __name__ == "__main__":
principale()
|
Ta Réponse
StackExchange.ifUsing ("éditeur", fonction () {
StackExchange.using ("externalEditor", function () {
StackExchange.using ("extraits", function () {
StackExchange.snippets.init ();
});
});
}, "extraits de code");
StackExchange.ready (fonction () {
var channelOptions = {
tags: "" .split (""),
id: "1"
};
initTagRenderer ("". split (""), "" .split (""), channelOptions);
StackExchange.using ("externalEditor", function () {
// Doit lancer l'éditeur après les extraits, si les extraits sont activés
if (StackExchange.settings.snippets.snippetsEnabled) {
StackExchange.using ("extraits", function () {
createEditor ();
});
}
autre {
createEditor ();
}
});
function createEditor () {
StackExchange.prepareEditor ({
useStacksEditor: faux,
heartbeatType: 'réponse',
autoActivateHeartbeat: faux,
convertImagesToLinks: vrai,
noModals: vrai,
showLowRepImageUploadWarning: vrai,
reputationToPostImages: 10,
bindNavPrevention: vrai,
suffixe: "",
imageUploader: {
brandingHtml: "Powered by \ u003ca href = \" https: //imgur.com/ \ "\ u003e \ u003csvg class = \" svg-icon \ "width = \" 50 \ "height = \" 18 \ "viewBox = \ "0 0 50 18 \" fill = \ "none \" xmlns = \ "http: //www.w3.org/2000/svg \" \ u003e \ u003cpath d = \ "M46.1709 9.17788C46.1709 8.26454 46.2665 7.94324 47.1084 7.58816C47.4091 7.46349 47.7169 7.36433 48.0099 7.26993C48.9099 6.97997 49.672 6.73443 49.672 5.93063C49.672 5.22043 48.9832 4.61182 48.1414 4.61182C47.4335 4.61182 46.72556.9762.6943 4.61182C47.4335 4.61182 46.72554.91628 46.094 4.68.4335 4.61182 46.72554.91628 46.094.49.48.4335 4.61182 46.72554.91628 46.094.49.48.4335 4.61182 46.7256 4.9762.698 43.1481 6.59048V11.9512C43.1481 13.2535 43.6264 13.8962 44.6595 13.8962C45.6924 13.8962 46.1709 13.253546.1709 11.9512V9.17788Z \ "/ \ u003e \ u003cpath d = \" M32.492 10.1419C32.492 12.6954 34.1182 14.0484 37.0451 14.0484C39.9723 14.0484 41.5985 12.6954 41.5985 10.1419V6.59049C62 41.5985 10.1419V6.59049C6221.59854 4.63932C41.59854 4.63932C4.63932C62 38,5948 5,28821 38,5948 6,59049V9,60062C38,5948 10,8521 38,2696 11,5455 37,0451 11,5455C35,8209 11,5455 35,4954 10,8521 35,4954 9,60062V6,59049C35,4954 5,28821 35,0173 4,66232,35,4954 5,28821 35,0173 4,66232 34,9 u003 4,692 5,28821 35,0173 4,66232/34,9 u003 4,692c 4,692C fill-rule = \ "evenodd \" clip-rule = \ "evenodd \" d = \ "M25.6622 17.6335C27.8049 17.6335 29.3739 16.9402 30.2537 15.6379C30.8468 14.7755 30.9615 13.5579 30.9615 11.9512V6.59049C30.9615 4.68821 30.9633 29.4502 4.66231C28.9913 4.66231 28.4555 4.94978 28.1109 5.50789C27.499 4.86533 26.7335 4.56087 25.7005 4.56087C23.1369 4.56087 21.0134 6.57349 21.0134 9.27932C21.0134 11.9852 23.003 13.913 25.3759.54 11.9852 23.003 13.913 25.3759.4 12.9852 23.003 13.913 25.3759.4 12.9852 23.003 13.913 25.3759.4 12.9852 23.003 13.913 25.3759.4 12.9134 12.716.1139.28.710C 2812710.4 13.9139.1 2812710.4 C28. 1256 12.8854 28,1301 12,9342 28,1301 14,4373 27,2502 15,2321 12.983C28.1301 25,777 15.2321C24.8349 15,2321 24,1352 14,9821 23,5661 22,8472 14,5218 14.7787C23.176 14,6393 22,5437 21,2429 15,0123 14.5218C21.7977 14,5218 21,2429 22,9072 17,6335 15.6887C21.2429 16,7375 25,6622 17.6335ZM24.1317 9,27932 C24.1317 7.94324 24.9928 7.09766 26.1024 7.09766C27.2119 7.09766 28.0918 7.94324 28.0918 9.27932C28.0918 10.6321 27.2311 11.5116 26.1024 11.5116C24.9737 11.5116 24.1317 10.6491 24.1317 9.27932Z \ "/" M u0016.80 = 24.1317 9.27932Z \ "/" M u0016.80 =. 8045 13.2535 17.2637 13.8962 18.2965 13.8962C19.3298 13.8962 19.8079 13.2535 19.8079 11.9512V8.12928C19.8079 5.82936 18.4879 4.62866 16.4027 4.62866C15.1594 4.62866 14.279 4.98375 13.3609 5.88013C12.653 5.0580 4.666 4.828.96.48.48.375 13.3609 5.88013C12.653 5.0580 58314 4.9328 7.10506 4.66232 6.51203 4.66232C5.47873 4.66232 5.00066 5.28821 5.00066 6.59049V11.9512C5.00066 13.2535 5.47873 13.8962 6.51203 13.8962C7.54479 13.8962 8.0232 13 .2535 8.0232 11.9512V8.90741C8.0232 7.58817 8.44431 6.91179 9.53458 6.91179C10.5104 6.91179 10.893 7.58817 10.893 8.94108V11.9512C10.893 13.2535 11.3711 13.8962 12.4044 13.8962C13.43458 6.91179C10.5104 6.91179 10.893 7.58817 10.893 8.94108V11.9512C10.893 13.2535 11.3711 13.8962 12.4044 13.8962C13.43458 6.91179C10.5104 6.91179 10.893 7.58817 10.893 8.94108V11.9512C10.893 13.2535 11.3711 13.8962 12.4044 13.8962C13.43458 6.91179C10.5104 6.91179 10.893 7.58817 10.893 8.94108V11.9512C10.893 13.2535 11.3711 13.8962 12.4044 13.8962C13.43458 C16.4027 6.91179 16.8045 7.58817 16.8045 8.94108V11.9512Z \ "/ \ u003e \ u003cpath d = \" M3.31675 6.59049C3.31675 5.28821 2.83866 4.66232 1.82471 4.66232C0.791758 4.66232 0.313354 5.28821 0.31335584 139049V 1.82471 13.8962C2.85798 13.8962 3.31675 13.2535 3.31675 11.9512V6.59049Z \ "/ \ u003e \ u003cpath d = \" M1.87209 0.400291C0.843612 0.400291 0 1.1159 0 1.98861C0 2.87869 0.822846 3.576772 3.8619 3.57209 C3.7234 1.1159 2.90056 0.400291 1.87209 0.400291Z \ "fill = \" # 1BB76E \ "/ \ u003e \ u003c / svg \ u003e \ u003c / a \ u003e",
contentPolicyHtml: "Contributions des utilisateurs sous licence \ u003ca href = \" https: //stackoverflow.com/help/licensing \ "\ u003ecc by-sa \ u003c / a \ u003e \ u003ca href = \" https://stackoverflow.com / legal / content-policy \ "\ u003e (politique de contenu) \ u003c / a \ u003e",
allowUrls: vrai
},
onDemand: vrai,
discardSelector: ".discard-answer"
, immédiatementShowMarkdownHelp: true, enableTables: true, enableSnippets: true
});
}
});
Merci d'avoir répondu à Stack Overflow!
Veuillez vous assurer de répondre à la question. Fournissez des détails et partagez vos recherches!
Mais évitez…
Demander de l'aide, des éclaircissements ou répondre à d'autres réponses.
Faire des déclarations basées sur des opinions; les sauvegarder avec des références ou une expérience personnelle.
Pour en savoir plus, consultez nos conseils sur la rédaction de bonnes réponses.
Brouillon enregistré
Brouillon rejeté
Inscrivez-vous ou connectez-vous
StackExchange.ready (fonction () {
StackExchange.helpers.onClickDraftSave ('# login-link');
});
Inscrivez-vous avec Google
Inscrivez-vous via Facebook
Inscrivez-vous par e-mail et mot de passe
Nous faire parvenir
Publier en tant qu'invité
Nom
E-mail
Obligatoire, mais jamais affiché
StackExchange.ready (
fonction () {
StackExchange.openid.initPostLogin ('. New-post-login', 'https% 3a% 2f% 2fstackoverflow.com% 2fquestions% 2f62245173% 2fpytorch-transform-tensor-to-one-hot% 23new-answer', 'question_page' );
}
);
Publier en tant qu'invité
Nom
E-mail
Obligatoire, mais jamais affiché
Publiez votre réponse
Jeter
En cliquant sur «Publier votre réponse», vous acceptez nos conditions d'utilisation, notre politique de confidentialité et notre politique de cookies
Ce n'est pas la réponse que vous recherchez? Parcourez les autres questions marquées python pytorch tensor one-hot-encoding ou posez votre propre question.